#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
#####################################################
import copy

import torch, random
import torch.nn as nn
from torch.distributions.categorical import Categorical

from exps.utils import encoded_arch
from .genotypes import Structure

import time
import numpy as np
import math


class Controller(nn.Module):
    """LSTM controller that chooses one operation for every edge of a cell.

    Edges are visited in index order ``0 .. num_edge - 1``. At each step the
    LSTM output is projected to logits over ``op_names``; the chosen op index is
    embedded and fed back as the next LSTM input.
    """

    # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
    def __init__(
        self,
        edge2index,
        op_names,
        max_nodes,
        lstm_size=32,
        lstm_num_layers=2,
        tanh_constant=2.5,
        temperature=5.0,
    ):
        """Build the controller.

        Args:
            edge2index: mapping from edge strings of the form ``"i<-j"`` to the
                position of that edge's decision in an action list.
            op_names: candidate operation names; an action index selects one.
            max_nodes: number of cell nodes; every node ``i`` in
                ``1 .. max_nodes - 1`` has an incoming edge from each ``j < i``.
            lstm_size: hidden size of the LSTM and of the op embeddings.
            lstm_num_layers: number of stacked LSTM layers.
            tanh_constant: scale applied after ``tanh`` squashing of the logits.
            temperature: logits are divided by this before the ``tanh``.
        """
        super(Controller, self).__init__()
        # assign the attributes
        self.max_nodes = max_nodes
        self.num_edge = len(edge2index)
        self.edge2index = edge2index
        self.num_ops = len(op_names)
        self.op_names = op_names
        self.lstm_size = lstm_size
        self.lstm_N = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature

        # create parameters
        self.register_parameter(
            "input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
        )
        self.w_lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.lstm_N,
        )
        self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
        self.w_pred = nn.Linear(self.lstm_size, self.num_ops)

        # NOTE: only the first LSTM layer gets the custom uniform init; any
        # deeper layers keep PyTorch's default initialisation.
        nn.init.uniform_(self.input_vars, -0.1, 0.1)
        nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
        nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
        nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
        nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)

    def convert_structure(self, _arch):
        """Turn a list of per-edge op indices into a ``Structure``.

        Args:
            _arch: op indices, indexed by the edge positions in ``edge2index``.

        Returns:
            ``Structure`` built from, for each node ``i >= 1``, a tuple of
            ``(op_name, j)`` pairs for its incoming edges ``j < i``.
        """
        genotypes = []
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = "{:}<-{:}".format(i, j)
                op_index = _arch[self.edge2index[node_str]]
                op_name = self.op_names[op_index]
                xlist.append((op_name, j))
            genotypes.append(tuple(xlist))
        return Structure(genotypes)

    def _rollout(self, actions_index=None):
        """Run the LSTM over every edge, either sampling or replaying actions.

        Args:
            actions_index: optional indexable of op indices (one per edge, in
                edge-index order; tensor, list or array). When ``None``, each op
                is sampled from the current policy; otherwise the given ops are
                replayed and scored.

        Returns:
            ``(x_probs, log_prob, entropy, sampled_arch)``: ``x_probs`` is a list
            of per-edge softmax probabilities rounded to 2 decimals (numpy
            arrays), ``log_prob`` and ``entropy`` are scalar tensors summed over
            edges, and ``sampled_arch`` is the list of chosen op indices.
        """
        inputs, hidden = self.input_vars, None
        x_probs, log_probs, entropys, sampled_arch = [], [], [], []

        for iedge in range(self.num_edge):
            outputs, hidden = self.w_lstm(inputs, hidden)

            logits = self.w_pred(outputs) / self.temperature
            logits = self.tanh_constant * torch.tanh(logits)

            x_prob = torch.softmax(logits, dim=-1)
            x_probs.append(np.round(x_prob.view(-1).tolist(), 2))

            op_distribution = Categorical(logits=logits)
            if actions_index is None:
                op_index = op_distribution.sample()
            else:
                # Shape [1, 1] to match the batch shape of `logits`; place it on
                # the same device as the model so embedding/log_prob line up.
                op_index = torch.as_tensor(
                    actions_index[iedge], device=logits.device
                ).view(1, 1)
            sampled_arch.append(op_index.item())

            log_probs.append(op_distribution.log_prob(op_index).view(-1))
            entropys.append(op_distribution.entropy().view(-1))

            # obtain the input embedding for the next step
            inputs = self.w_embd(op_index)

        return (
            x_probs,
            torch.sum(torch.cat(log_probs)),
            torch.sum(torch.cat(entropys)),
            sampled_arch,
        )

    def get_prob(self, actions_index):
        """Score a previously chosen architecture under the current policy.

        Args:
            actions_index: per-edge op indices to replay.

        Returns:
            ``(x_probs, log_prob, entropy, structure)`` where ``structure`` is
            the ``Structure`` for ``actions_index``.
        """
        x_probs, log_prob, entropy, sampled_arch = self._rollout(actions_index)
        return x_probs, log_prob, entropy, self.convert_structure(sampled_arch)

    def forward(self):
        """Sample a new architecture from the policy.

        Returns:
            ``(x_probs, log_prob, entropy, structure, sampled_arch)`` where
            ``sampled_arch`` is the list of sampled per-edge op indices.
        """
        x_probs, log_prob, entropy, sampled_arch = self._rollout()
        return (
            x_probs,
            log_prob,
            entropy,
            self.convert_structure(sampled_arch),
            sampled_arch,
        )


class ControllerTrainer(object):
    """Run the controller-based architecture search loop.

    Architectures are sampled from ``controller``. Each one is scored either by
    the simulated benchmark (``nas_bench``) or by an acquisition function
    (``acq_fn``). The collected ``(actions_index, log_prob, reward)`` tuples
    are then used to update the controller with policy gradient or PPO.
    """

    def __init__(self, controller, optimizer, predictor, nas_bench, args, log):
        self.controller = controller
        self.optimizer = optimizer
        self.nas_bench = nas_bench
        self.predictor = predictor

        self.log = log
        self.args = args

        # [arch_str, score] pairs: "true_info" scored by nas_bench,
        # "pred_info" scored by acq_fn.
        self.buffer_unique_archs = {"true_info": [], "pred_info": []}
        # Every arch string sampled so far; used to avoid re-sampling duplicates.
        self.arch_buffer = []
        # (actions_index, log_prob, reward) tuples for the next policy update.
        self.agent_buffer = []
        # Accumulated evaluation cost, compared against args.time_budget.
        self.cur_total_time_costs = 0.0
        self.alter_flag = True

        # Number of top predicted archs queued for true evaluation at a time.
        self.topn = 5
        self.next_group = []

        self.acq_fn = None

    def set_predictor(self, predictor):
        """Replace the predictor object."""
        self.predictor = predictor

    def set_acq_fn(self, acq_fn):
        """Set the acquisition function used when sampling with ``is_predictor``."""
        self.acq_fn = acq_fn

    def get_unique_arch_by_random(self, sampler):
        """Call ``sampler()`` until it returns an arch not yet in ``arch_buffer``.

        NOTE: loops forever if every reachable arch has already been sampled.
        """
        sampled_arch = sampler()
        while sampled_arch.tostr() in self.arch_buffer:
            sampled_arch = sampler()
        return sampled_arch

    def get_unique_arch(self):
        """Sample from the controller until the arch is not yet in ``arch_buffer``.

        Returns:
            The controller's tuple
            ``(probs, log_prob, entropy, sampled_arch, actions_index)``.
        """
        probs, log_prob, entropy, sampled_arch, actions_index = self.controller()
        while sampled_arch.tostr() in self.arch_buffer:
            probs, log_prob, entropy, sampled_arch, actions_index = self.controller()
        return probs, log_prob, entropy, sampled_arch, actions_index

    def controller_sample(self, steps, baseline, is_predictor):
        """Sample ``steps`` unseen archs, score them and append to ``agent_buffer``.

        Args:
            steps: number of archs to sample.
            baseline: current moving-average baseline, or ``None`` before the
                first sample.
            is_predictor: score with ``acq_fn`` when true, otherwise with
                ``nas_bench`` (which also adds to the time cost).

        Returns:
            The updated baseline.
        """
        decay = 0.95
        controller_entropy_weight = 0.0001

        for _ in range(steps):
            probs, log_prob, entropy, sampled_arch, actions_index = self.get_unique_arch()
            arch_str = sampled_arch.tostr()

            if is_predictor:
                time_start = time.time()
                if not self.args.is_ensemble:
                    val_top1 = self.acq_fn.query([arch_str])
                else:
                    val_top1 = self.acq_fn(arch_str)
                time_cost = time.time() - time_start
                self.buffer_unique_archs["pred_info"].append([arch_str, val_top1])
            else:
                val_top1, _, time_cost = self.nas_bench.get_simul_train_epoch12_info(self.args.dataset, sampled_arch)
                self.buffer_unique_archs["true_info"].append([arch_str, val_top1])

            self.arch_buffer.append(arch_str)

            val_top1 = torch.tensor(val_top1)
            # The entropy is detached: reward and baseline are only used as
            # constants in the loss, and leaving them attached would chain every
            # sampled graph into the running baseline (unbounded memory growth).
            reward = val_top1 + controller_entropy_weight * entropy.detach()

            if baseline is None:
                baseline = val_top1
            else:
                baseline = decay * baseline + (1 - decay) * reward

            self.log.info(
                "Is_predictor: {:}, Baseline: {:.3f}, Val_acc: {:.4f}, Reward:{:.4f}, Time_cost: {:4f}, Arch: {:}".format(
                    is_predictor, baseline.item(), val_top1, reward.item(), time_cost, arch_str
                )
            )

            if torch.is_tensor(reward):
                reward = reward.tolist()
            if torch.is_tensor(log_prob):
                log_prob = log_prob.tolist()

            self.agent_buffer.append((actions_index, log_prob, reward))

            self.cur_total_time_costs += time_cost

        return baseline

    def cal_loss(self, log_p, log_old_p, reward, baseline):
        """Return the per-sample policy loss.

        ``"ppo"`` uses the clipped-ratio surrogate with ``args.clip_ratio``;
        ``"pg"`` uses ``-log_p * (reward - baseline)``.

        Raises:
            ValueError: if ``args.update_agent_algo`` is neither ``"ppo"`` nor ``"pg"``.
        """
        if self.args.update_agent_algo == "ppo":
            ratio = torch.exp(log_p - log_old_p)
            adv = reward - baseline
            clip_adv = torch.clamp(ratio, 1 - self.args.clip_ratio, 1 + self.args.clip_ratio) * adv
            policy_loss = -(torch.min(ratio * adv, clip_adv))
        elif self.args.update_agent_algo == "pg":
            policy_loss = -1 * log_p * (reward - baseline)
        else:
            raise ValueError("Invalid agent's algo : {:}".format(self.args.update_agent_algo))

        return policy_loss

    def train_policy(self, baseline):
        """Update the controller for ``args.update_policy_epochs`` epochs on ``agent_buffer``.

        Each epoch re-scores every buffered action sequence under the current
        policy and takes one optimizer step on the mean loss. Does nothing when
        the buffer is empty.
        """
        if not self.agent_buffer:
            # Nothing was sampled this round (e.g. t_steps=0 without predictor).
            return

        baseline = baseline.detach()
        for _ in range(self.args.update_policy_epochs):
            loss = 0
            for actions_index, log_old_prob, reward in self.agent_buffer:
                actions_index = torch.as_tensor(actions_index)
                log_old_prob = torch.as_tensor(log_old_prob)
                reward = torch.as_tensor(reward)

                if torch.cuda.is_available():
                    actions_index = actions_index.cuda()
                    log_old_prob = log_old_prob.cuda()
                    reward = reward.cuda()

                _, log_prob, _, _ = self.controller.get_prob(actions_index)
                loss += self.cal_loss(log_prob, log_old_prob, reward, baseline)

            loss /= len(self.agent_buffer)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def get_cur_best_arch(self):
        """Report the current best architecture.

        It compares two candidates:
        - the top ``true_info`` arch;
        - the next queued top-``topn`` predicted arch.

        The predicted arch is evaluated with ``nas_bench`` if it has not been
        evaluated yet.

        Returns:
            ``{source: [arch, val_acc, test_acc]}``, where ``source`` is
            ``"pred_info"`` or ``"true_info"``. Returns ``None`` when nothing has
            been sampled yet.
        """
        true_info = self.buffer_unique_archs["true_info"]
        pred_info = self.buffer_unique_archs["pred_info"]
        if not true_info and not pred_info:
            return None

        best_arch, best_val = None, None
        if true_info:
            sorted_true_data = sorted(true_info, key=lambda v: (v[1], v[0]), reverse=True)
            # choose the top1
            best_arch, best_val = sorted_true_data[0][0], sorted_true_data[0][1]

        best_pred_arch, pred_best_val = None, 0.
        if pred_info:
            if not self.next_group:
                sorted_pred_data = sorted(pred_info, key=lambda v: (v[1], v[0]), reverse=True)
                top_n = sorted_pred_data[:self.topn]
                # Reversed so that pop() yields the highest-scored arch first.
                top_n.reverse()
                self.next_group = list(top_n)

            best_pred_arch = self.next_group.pop()[0]

            true_info_archs = [info[0] for info in true_info]
            if best_pred_arch in true_info_archs:
                pred_best_val = true_info[true_info_archs.index(best_pred_arch)][1]
            else:
                val_top1, _, time_cost = self.nas_bench.get_simul_train_epoch12_info(self.args.dataset, best_pred_arch)
                pred_best_val = val_top1
                true_info.append([best_pred_arch, val_top1])
                self.cur_total_time_costs += time_cost

        use_pred = best_pred_arch is not None and (best_arch is None or pred_best_val > best_val)
        if use_pred:
            pred_best_val, pred_best_test = self.nas_bench.get_simul_full_train_info(self.args.dataset, best_pred_arch, deterministic=True)
            self.log.info("From: {}, cur_best_arch: {}, cur_total_time_costs: {}, val_acc: {}, test_acc: {}".format("pred_info", best_pred_arch, self.cur_total_time_costs, pred_best_val, pred_best_test))
            return {"pred_info": [best_pred_arch, pred_best_val, pred_best_test]}

        best_val, best_test = self.nas_bench.get_simul_full_train_info(self.args.dataset, best_arch, deterministic=True)
        self.log.info("From: {}, cur_best_arch: {}, cur_total_time_costs: {}, val_acc: {}, test_acc: {}".format("true_info", best_arch, self.cur_total_time_costs, best_val, best_test))
        return {"true_info": [best_arch, best_val, best_test]}

    def train_controller(self, baseline, t_steps=20, steps=20):
        """Run one sampling + policy-update round.

        Samples ``t_steps`` archs scored by ``nas_bench``. When
        ``args.is_predictor == "True"``, it also samples ``steps - t_steps``
        archs scored by ``acq_fn``. It then logs the current best arch and
        updates the policy. Returns the updated baseline; it returns the input
        baseline unchanged once the time budget is exceeded.
        """
        if self.cur_total_time_costs > self.args.time_budget:
            return baseline

        self.controller.train()
        self.controller.zero_grad()
        self.agent_buffer = []

        p_steps = 0
        if self.args.is_predictor == "True":
            p_steps = steps - t_steps

        # Samples scored with the true (simulated) benchmark.
        baseline = self.controller_sample(t_steps, baseline, False)

        # Samples scored with the predictor / acquisition function.
        if self.args.is_predictor == "True":
            baseline = self.controller_sample(p_steps, baseline, True)

        # Determine the current best architecture: key=source, val=[arch, val_acc, test_acc].
        _ = self.get_cur_best_arch()

        self.train_policy(baseline)

        return baseline

    def pre_train_controller(self):
        """Run up to ``args.num_batch_per_epoch`` rounds with a fixed true/pred split.

        The true-sample count per round is set by ``args.predictor_mode``:
        ``"None"`` gives 20, ``"all"`` gives 0, and ``"fixed_k"`` gives
        ``args.fixedk``.

        Raises:
            ValueError: for any other ``predictor_mode``.
        """
        baseline = None

        for _ in range(self.args.num_batch_per_epoch):
            if self.cur_total_time_costs > self.args.time_budget:
                break

            if self.args.predictor_mode == "None":
                t_steps = 20
            elif self.args.predictor_mode == "all":
                t_steps = 0
            elif self.args.predictor_mode == "fixed_k":
                t_steps = self.args.fixedk
            else:
                raise ValueError("Invalid value: {:}".format(self.args.predictor_mode))

            baseline = self.train_controller(baseline, t_steps=t_steps)

    def pre_train_controller_by_kl(self):
        """Run rounds whose true/pred split adapts to policy drift.

        A reference batch of ``args.episodes`` samples is recorded first. Each
        round, the KL check against that batch is clamped to ``[0, 1]``, and
        that fraction of the 20 steps is scored with ``nas_bench``.
        """
        baseline = None

        # Save one batch of reference samples (action indices + per-edge probs).
        # No gradients are needed for this bookkeeping.
        temp_buffer = []
        with torch.no_grad():
            for _ in range(self.args.episodes):
                probs, log_prob, entropy, sampled_arch, actions_index = self.get_unique_arch()
                if torch.is_tensor(probs):
                    probs = probs.tolist()
                temp_buffer.append((actions_index, probs))

        steps = 20
        for _ in range(self.args.num_batch_per_epoch):
            if self.cur_total_time_costs > self.args.time_budget:
                break

            alpha = self.kl_check(temp_buffer)
            if alpha <= 0:
                alpha = 0
            elif alpha >= 1:
                alpha = 1

            adap_steps = int(alpha * steps)

            self.log.info("alpha: {}, steps: {}, true_steps: {}, pred_steps: {}".format(alpha, steps, adap_steps, steps - adap_steps))
            baseline = self.train_controller(baseline, t_steps=adap_steps)

    def kl_check(self, buffer):
        """Measure how far the current policy has drifted from ``buffer``.

        For every ``(actions_index, target_probs)`` pair, the current per-edge
        probabilities are recomputed and ``kl_div(log(pred + 1e-5), target)`` is
        averaged over edges.

        Returns:
            The sum of those per-sample averages (``0.0`` for an empty buffer).
            NOTE(review): the summed value, not ``avg_kl_loss``, is returned and
            callers clamp it to ``[0, 1]``, so it saturates quickly for large
            buffers; confirm which is intended.
        """
        if not buffer:
            return 0.0

        kl_loss = 0
        with torch.no_grad():
            for actions_index, target in buffer:
                actions_index = torch.as_tensor(actions_index)
                if torch.cuda.is_available():
                    actions_index = actions_index.cuda()

                preds, _, _, _ = self.controller.get_prob(actions_index)

                cur_loss = 0
                for pred, tgt in zip(preds, target):
                    cur_loss += torch.nn.functional.kl_div(torch.log(torch.tensor(pred) + 1e-5), torch.tensor(tgt), reduction='sum').item()

                kl_loss += cur_loss / len(target)

        avg_kl_loss = kl_loss / len(buffer)
        self.log.info("kl_loss: {}, avg_kl_loss: {}".format(kl_loss, avg_kl_loss))
        return kl_loss
